using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
#if TEST
using Xunit;
#endif

namespace Fadd.Commands.Net
{
    /// <summary>
    /// Transports commands in binary format.
    /// </summary>
    /// <remarks>
    /// <para>
    /// Each object is serialized with <see cref="BinaryFormatter"/> and written as one frame: a
    /// 4-byte length prefix (produced by <see cref="BitConverter.GetBytes(int)"/>) followed by the
    /// serialized bytes. Incoming data may arrive split or coalesced arbitrarily; frames are
    /// reassembled before being deserialized and passed to <see cref="ObjectReceived"/>.
    /// </para>
    /// <para>
    /// SECURITY: deserializing attacker-controlled data with <see cref="BinaryFormatter"/> is unsafe.
    /// Only connect this channel to trusted peers.
    /// </para>
    /// </remarks>
    public class BinaryChannel : IDisposable
    {
        private const int BufferSize = 4192;

        // Size of the length prefix written in front of every frame.
        private const int HeaderSize = sizeof(int);

        // Delay, in milliseconds, before each reconnection attempt.
        private const int ReconnectInterval = 15000;

        private readonly BinaryFormatter _formatter = new BinaryFormatter();
        private readonly byte[] _inbuffer = new byte[BufferSize];

        // Length prefix of the incoming frame; it may arrive in pieces across several receives.
        private readonly byte[] _inheader = new byte[HeaderSize];
        private readonly Packet _inpacket = new Packet();
        private readonly Packet _outPacket = new Packet();

        // Also used as the lock guarding _outPacket.
        private readonly Queue<byte[]> _sendQueue = new Queue<byte[]>();
        private int _inheaderLength;
        private Timer _reconnectTimer;
        private IPEndPoint _remoteEndPoint;
        private volatile bool _shouldReconnect;
        private Socket _socket;

        /// <summary>
        /// Invoked when we have received an object from the remote end.
        /// </summary>
        public ObjectReceivedHandler ObjectReceived = delegate { };

        /// <summary>
        /// Initializes a new instance of the <see cref="BinaryChannel"/> class around an already
        /// connected socket and starts receiving from it.
        /// </summary>
        /// <param name="socket">A connected socket.</param>
        /// <exception cref="ArgumentNullException"><paramref name="socket"/> is null.</exception>
        public BinaryChannel(Socket socket)
        {
            if (socket == null)
                throw new ArgumentNullException("socket");

            _socket = socket;
            _remoteEndPoint = (IPEndPoint) socket.RemoteEndPoint;
            StartReceive(socket);
        }

        /// <summary>
        /// Initializes a new, unconnected instance of the <see cref="BinaryChannel"/> class.
        /// Call <see cref="Open"/> to connect it.
        /// </summary>
        public BinaryChannel()
        {
        }

        /// <summary>
        /// true if we should reconnect (every 15 seconds until it succeeds) when getting disconnected.
        /// </summary>
        public bool ShouldReconnect
        {
            get { return _shouldReconnect; }
            set { _shouldReconnect = value; }
        }

        /// <summary>
        /// Invoked when the channel is disconnected (not raised when <see cref="Close"/> or
        /// <see cref="Dispose"/> is called).
        /// </summary>
        public event DisconnectedHandler Disconnected = delegate { };

        /// <summary>
        /// Serializes the specified value and queues it for sending.
        /// </summary>
        /// <param name="value">object to serialize and send.</param>
        public void Send(object value)
        {
            byte[] buffer;
            using (MemoryStream ms = new MemoryStream())
            {
                _formatter.Serialize(ms, value);
                buffer = ms.ToArray();
            }
            Send(buffer);
        }

        ///<summary>
        /// Queues an already serialized packet for sending to the remote end.
        ///</summary>
        ///<param name="bytes">Packet payload; the length prefix is added automatically.</param>
        /// <exception cref="ArgumentNullException"><paramref name="bytes"/> is null.</exception>
        public void Send(byte[] bytes)
        {
            if (bytes == null)
                throw new ArgumentNullException("bytes");

            lock (_sendQueue)
            {
                _sendQueue.Enqueue(bytes);
            }

            Send();
        }

        /// <summary>
        /// Starts sending the next queued packet, unless a send is already in progress,
        /// nothing is queued, or the channel is not connected.
        /// </summary>
        private void Send()
        {
            Socket socket;
            lock (_sendQueue)
            {
                socket = _socket;
                if (socket == null || !socket.Connected)
                    return;
                if (_outPacket.buffer != null || _sendQueue.Count == 0)
                    return;

                // Header and payload go out as one buffer, so a short write of the header is
                // handled like any other partial send instead of corrupting the stream.
                byte[] payload = _sendQueue.Dequeue();
                byte[] frame = new byte[HeaderSize + payload.Length];
                BitConverter.GetBytes(payload.Length).CopyTo(frame, 0);
                Buffer.BlockCopy(payload, 0, frame, HeaderSize, payload.Length);

                _outPacket.buffer = frame;
                _outPacket.index = 0;
                _outPacket.size = frame.Length;
            }

            BeginSendRemaining(socket);
        }

        /// <summary>
        /// Sends the not yet transmitted part of <see cref="_outPacket"/>.
        /// </summary>
        private void BeginSendRemaining(Socket socket)
        {
            byte[] buffer;
            int offset;
            int count;
            lock (_sendQueue)
            {
                buffer = _outPacket.buffer;
                if (buffer == null)
                    return; // dropped by a concurrent disconnect.
                offset = _outPacket.index;
                count = _outPacket.size - offset;
            }

            try
            {
                socket.BeginSend(buffer, offset, count, SocketFlags.None, OnSendComplete, socket);
            }
            catch (SocketException err)
            {
                HandleDisconnect(socket, err.SocketErrorCode);
            }
            catch (ObjectDisposedException)
            {
                HandleDisconnect(socket, SocketError.ConnectionReset);
            }
        }

        private void OnSendComplete(IAsyncResult ar)
        {
            Socket socket = (Socket) ar.AsyncState;
            try
            {
                SocketError errorCode;
                int bytesSent = socket.EndSend(ar, out errorCode);
                if (errorCode != SocketError.Success)
                {
                    HandleDisconnect(socket, errorCode);
                    return;
                }
                if (bytesSent == 0)
                {
                    HandleDisconnect(socket, SocketError.ConnectionReset);
                    return;
                }

                bool frameCompleted;
                lock (_sendQueue)
                {
                    if (_outPacket.buffer == null)
                        return; // dropped by a concurrent disconnect.

                    _outPacket.index += bytesSent;
                    frameCompleted = _outPacket.index >= _outPacket.size;
                    if (frameCompleted)
                        _outPacket.Clear();
                }

                if (frameCompleted)
                    Send();
                else
                    BeginSendRemaining(socket);
            }
            catch (SocketException err)
            {
                HandleDisconnect(socket, err.SocketErrorCode);
            }
            catch (ObjectDisposedException)
            {
                HandleDisconnect(socket, SocketError.ConnectionReset);
            }
        }

        /// <summary>
        /// Resets the inbound frame state and begins receiving on <paramref name="socket"/>.
        /// </summary>
        private void StartReceive(Socket socket)
        {
            _inpacket.Clear();
            _inheaderLength = 0;
            socket.BeginReceive(_inbuffer, 0, BufferSize, SocketFlags.None, OnReceiveComplete, socket);
        }

        private void OnReceiveComplete(IAsyncResult ar)
        {
            Socket socket = (Socket) ar.AsyncState;
            try
            {
                SocketError errorCode;
                int bytesRead = socket.EndReceive(ar, out errorCode);
                if (errorCode != SocketError.Success)
                {
                    HandleDisconnect(socket, errorCode);
                    return;
                }

                if (bytesRead == 0)
                {
                    // The remote end closed the connection gracefully.
                    HandleDisconnect(socket, SocketError.Success);
                    return;
                }

                if (socket != _socket)
                    return; // closed or replaced while the receive was pending.

                // Each call consumes at most one frame, so loop until every byte has been handled.
                int offset = 0;
                while (offset < bytesRead)
                    offset += ProcessInBuffer(_inbuffer, offset, bytesRead - offset);

                socket.BeginReceive(_inbuffer, 0, BufferSize, SocketFlags.None, OnReceiveComplete, socket);
            }
            catch (ObjectDisposedException)
            {
                HandleDisconnect(socket, SocketError.ConnectionReset);
            }
            catch (SocketException err)
            {
                HandleDisconnect(socket, err.SocketErrorCode);
            }
            catch (InvalidDataException err)
            {
                Console.WriteLine(err.Message);
                HandleDisconnect(socket, SocketError.Success);
            }
        }

        /// <summary>
        /// Consumes incoming bytes into the current frame (length prefix first, then the payload).
        /// One or more calls might be required to get a complete packet; a call never consumes
        /// bytes beyond the end of the frame it completes.
        /// </summary>
        /// <param name="inbuffer">buffer to process</param>
        /// <param name="index">offset in <paramref name="inbuffer"/> where processing starts</param>
        /// <param name="count">number of bytes available starting at <paramref name="index"/>.</param>
        /// <returns>number of bytes consumed; always greater than zero when <paramref name="count"/> is.</returns>
        /// <exception cref="InvalidDataException">the length prefix is negative.</exception>
        private int ProcessInBuffer(byte[] inbuffer, int index, int count)
        {
            int consumed = 0;

            if (_inpacket.buffer == null)
            {
                // Accumulate the length prefix; it may be split across several receives.
                int headerBytes = Math.Min(HeaderSize - _inheaderLength, count);
                Buffer.BlockCopy(inbuffer, index, _inheader, _inheaderLength, headerBytes);
                _inheaderLength += headerBytes;
                consumed = headerBytes;
                if (_inheaderLength < HeaderSize)
                    return consumed;

                _inheaderLength = 0;
                int size = BitConverter.ToInt32(_inheader, 0);
                if (size < 0)
                    throw new InvalidDataException("Invalid packet header: negative length " + size + ".");

                _inpacket.buffer = new byte[size];
                _inpacket.size = size;
                _inpacket.index = 0;
            }

            // copy object bytes.
            int bodyBytes = Math.Min(_inpacket.size - _inpacket.index, count - consumed);
            Buffer.BlockCopy(inbuffer, index + consumed, _inpacket.buffer, _inpacket.index, bodyBytes);
            _inpacket.index += bodyBytes;
            consumed += bodyBytes;

            if (_inpacket.index == _inpacket.size)
            {
                // Reset before dispatching so a throwing handler cannot leave a stale frame behind.
                byte[] completed = _inpacket.buffer;
                _inpacket.Clear();
                OnBufferReceived(completed);
            }

            return consumed;
        }

        /// <summary>
        /// Called when a object buffer have been received completely. Deserializes it and
        /// raises <see cref="ObjectReceived"/>; buffers that fail to deserialize are logged and dropped.
        /// </summary>
        /// <param name="buffer">The buffer.</param>
        protected virtual void OnBufferReceived(byte[] buffer)
        {
            try
            {
                object obj;
                // SECURITY: BinaryFormatter on network data is only safe with trusted peers.
                using (MemoryStream ms = new MemoryStream(buffer))
                {
                    obj = _formatter.Deserialize(ms);
                }

                ObjectReceivedHandler handler = ObjectReceived;
                if (handler != null)
                    handler(this, new ObjectReceivedEventArgs(obj));
            }
            catch (SerializationException err)
            {
                Console.WriteLine(err);
            }
        }

#if TEST
        private static byte[] TestCreatePacket(string text)
        {
            byte[] data;
            using (MemoryStream ms = new MemoryStream())
            {
                new BinaryFormatter().Serialize(ms, text);
                data = ms.ToArray();
            }

            byte[] packet = new byte[data.Length + HeaderSize];
            BitConverter.GetBytes(data.Length).CopyTo(packet, 0);
            Buffer.BlockCopy(data, 0, packet, HeaderSize, data.Length);
            return packet;
        }

        private static byte[] TestConcat(byte[] first, byte[] second)
        {
            byte[] all = new byte[first.Length + second.Length];
            Buffer.BlockCopy(first, 0, all, 0, first.Length);
            Buffer.BlockCopy(second, 0, all, first.Length, second.Length);
            return all;
        }

        [Fact]
        private void Test2InPackets()
        {
            byte[] packet1 = TestCreatePacket("hello");
            byte[] packet2 = TestCreatePacket("world");
            byte[] packet = TestConcat(packet1, packet2);

            int objectCount = 0;
            ObjectReceived += delegate { ++objectCount; };

            int consumed = ProcessInBuffer(packet, 0, packet.Length);
            Assert.Equal(packet1.Length, consumed);
            consumed = ProcessInBuffer(packet, packet1.Length, packet.Length - packet1.Length);
            Assert.Equal(packet2.Length, consumed);
            Assert.Equal(2, objectCount);
        }

        [Fact]
        private void TestPartialInpacket()
        {
            byte[] all = TestCreatePacket("hello");

            int objectCount = 0;
            ObjectReceived += delegate { ++objectCount; };
            int consumed = ProcessInBuffer(all, 0, all.Length - 3);
            Assert.Equal(all.Length - 3, consumed);

            consumed = ProcessInBuffer(all, all.Length - 3, 3);
            Assert.Equal(3, consumed);
            Assert.Equal(1, objectCount);
        }

        [Fact]
        private void TestSplitHeader()
        {
            byte[] all = TestCreatePacket("hello");

            int objectCount = 0;
            ObjectReceived += delegate { ++objectCount; };

            Assert.Equal(2, ProcessInBuffer(all, 0, 2));
            Assert.Equal(0, objectCount);
            Assert.Equal(all.Length - 2, ProcessInBuffer(all, 2, all.Length - 2));
            Assert.Equal(1, objectCount);
        }

        [Fact]
        private void TestSecondPartialInpacket()
        {
            BinaryChannel channel = new BinaryChannel();
            byte[] first = TestCreatePacket("hello");
            byte[] second = TestCreatePacket("world");
            byte[] all = TestConcat(first, second);

            int objectCount = 0;
            channel.ObjectReceived += delegate { ++objectCount; };

            // Only the first frame is consumed even though part of the next header is available.
            int index = channel.ProcessInBuffer(all, 0, first.Length + 4);
            Assert.Equal(first.Length, index);

            int consumed = channel.ProcessInBuffer(all, index, second.Length - 5);
            Assert.Equal(second.Length - 5, consumed);
            index += consumed;

            consumed = channel.ProcessInBuffer(all, index, 5);
            Assert.Equal(5, consumed);
            Assert.Equal(2, objectCount);
        }
#endif

        /// <summary>
        /// Tears down <paramref name="socket"/> after a failure, schedules a reconnect when
        /// <see cref="ShouldReconnect"/> is set and raises <see cref="Disconnected"/>.
        /// </summary>
        private void HandleDisconnect(Socket socket, SocketError code)
        {
            // Only the first failure reported for the current socket is handled. Callbacks from a
            // socket that was closed by Close()/Dispose() or already replaced are ignored, so the
            // event is raised at most once per connection and never for an intentional close.
            if (Interlocked.CompareExchange(ref _socket, null, socket) != socket)
                return;

            lock (_sendQueue)
            {
                // A half-sent frame cannot be resumed on a new connection; drop it so the
                // queue does not stay blocked forever.
                _outPacket.Clear();
            }

            socket.Close();

            if (_shouldReconnect && _remoteEndPoint != null && _reconnectTimer == null)
                _reconnectTimer = new Timer(TryConnect, null, ReconnectInterval, Timeout.Infinite);

            DisconnectedHandler handler = Disconnected;
            if (handler != null)
                handler(this, new DisconnectedEventArgs(code));
        }

        /// <summary>
        /// Reconnect timer callback. On failure the one-shot timer is re-armed, so attempts never overlap.
        /// </summary>
        private void TryConnect(object state)
        {
            if (!_shouldReconnect)
                return;

            IPEndPoint endPoint = _remoteEndPoint;
            Socket socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                socket.Connect(endPoint);
            }
            catch (SocketException)
            {
                socket.Close();
                Timer retry = _reconnectTimer;
                if (retry != null)
                {
                    try
                    {
                        retry.Change(ReconnectInterval, Timeout.Infinite);
                    }
                    catch (ObjectDisposedException)
                    {
                        // Closed or disposed meanwhile; nothing left to retry.
                    }
                }
                return;
            }

            Timer timer = Interlocked.Exchange(ref _reconnectTimer, null);
            if (timer != null)
                timer.Dispose();

            if (!_shouldReconnect)
            {
                // Closed or disposed while we were connecting.
                socket.Close();
                return;
            }

            _socket = socket;
            try
            {
                StartReceive(socket);
            }
            catch (SocketException err)
            {
                HandleDisconnect(socket, err.SocketErrorCode);
                return;
            }

            // Flush anything queued while we were disconnected.
            Send();
        }

        /// <summary>
        /// Closes this instance without raising <see cref="Disconnected"/>. Disables reconnection.
        /// Queued packets are kept and sent after the next <see cref="Open"/>.
        /// </summary>
        public void Close()
        {
            Shutdown();
        }

        /// <summary>
        /// Connect to an endpoint and start receiving. Packets queued while disconnected are sent.
        /// </summary>
        /// <param name="endPoint">Where to connect</param>
        /// <exception cref="ArgumentNullException"><paramref name="endPoint"/> is null.</exception>
        /// <exception cref="SocketException">if connection fails.</exception>
        public void Open(IPEndPoint endPoint)
        {
            if (endPoint == null)
                throw new ArgumentNullException("endPoint");

            Socket socket = _socket;
            bool created = socket == null;
            if (created)
                socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

            _remoteEndPoint = endPoint;
            try
            {
                socket.Connect(endPoint);
            }
            catch
            {
                // Don't leak a socket we created ourselves.
                if (created)
                    socket.Close();
                throw;
            }

            _socket = socket;
            StartReceive(socket);
            Send();
        }

        /// <summary>
        /// Disables reconnection, stops the reconnect timer, drops any half-sent frame and closes
        /// the socket. Pending socket callbacks then fail silently (see <see cref="HandleDisconnect"/>).
        /// </summary>
        private void Shutdown()
        {
            _shouldReconnect = false;

            Timer timer = Interlocked.Exchange(ref _reconnectTimer, null);
            if (timer != null)
                timer.Dispose();

            Socket socket = Interlocked.Exchange(ref _socket, null);
            lock (_sendQueue)
            {
                _outPacket.Clear();
            }

            if (socket != null)
            {
                try
                {
                    socket.Close();
                }
                catch (SocketException)
                {
                    // Best effort: the channel is shutting down anyway.
                }
            }
        }

        #region Implementation of IDisposable

        /// <summary>
        /// Closes the connection, stops reconnecting, discards queued packets and detaches
        /// <see cref="Disconnected"/> subscribers.
        /// </summary>
        /// <filterpriority>2</filterpriority>
        public void Dispose()
        {
            // No finalizer: the only unmanaged resource is the socket, which has its own.
            Shutdown();

            lock (_sendQueue)
            {
                _sendQueue.Clear();
            }
            Disconnected = null;
        }

        #endregion

        #region Nested type: Packet

        /// <summary>
        /// A frame being sent or received: its bytes, the current position and the total size.
        /// </summary>
        private class Packet
        {
            public byte[] buffer;
            public int index;
            public int size;

            public void Clear()
            {
                size = 0;
                index = 0;
                buffer = null;
            }
        }

        #endregion
    }

    /// <summary>
    /// Arguments for <see cref="ObjectReceivedHandler"/>: carries the object that was
    /// deserialized from a packet received from the remote end.
    /// </summary>
    public class ObjectReceivedEventArgs : EventArgs
    {
        private readonly object _received;

        /// <summary>
        /// Deserialized object received from the remote end.
        /// </summary>
        public object Object
        {
            get
            {
                return _received;
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="ObjectReceivedEventArgs"/> class.
        /// </summary>
        /// <param name="value">
        /// object received from remote end; validated with <c>Check.Require</c>
        /// (presumably a null check — confirm against the Check helper).
        /// </param>
        public ObjectReceivedEventArgs(object value)
        {
            Check.Require(value, "value");
            _received = value;
        }
    }

    /// <summary>
    /// Invoked when an object has been received from the remote end.
    /// </summary>
    /// <param name="source">Channel that received the object.</param>
    /// <param name="args">Event arguments carrying the received object.</param>
    public delegate void ObjectReceivedHandler(object source, ObjectReceivedEventArgs args);
}